{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "MiOo3478pJg9"
   },
   "source": [
    "# Before we start...\n",
    "\n",
    "This colab notebook is a minimum demo for faceswap-GAN v2.2. Since colab allows maximum run time limit of 12 hrs, we will only train a lightweight model in this notebook. **The purpose of this notebook is not to train a model that produces high quality results but a quick overview for how faceswap-GAN works.**\n",
    "\n",
    "The pipeline of faceswap-GAN v2.2 is described below:\n",
    "\n",
    "  1. Upload two videos for training.\n",
    "  2. Apply face extraction (preprocessing) on the two uploaded videos\n",
    "  3. Train a liteweight faceswap-GAN model. (This will take 10 ~ 12 hrs)\n",
    "  4. Apply video conversion to the uploaded videos."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Qf8Wvh5Wrm3V"
   },
   "source": [
    "# Step 1: Set runtime type to Python 3/GPU\n",
    "Set the colab notebook to GPU instance through: **runtime -> change runtime type -> Python3 and GPU**\n",
    "\n",
    "The following cells will show the system information of the current instance. Run the cells and check if it uses python >= 3.6 and has a GPU device."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "1zQhhPfKrmQl",
    "outputId": "96b19dbd-91a1-4f32-b5a3-54bd8461b674"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3.6.3\n"
     ]
    }
   ],
   "source": [
    "import platform\n",
    "print(platform.python_version())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 272
    },
    "colab_type": "code",
    "id": "1f96C1TApC00",
    "outputId": "eeeeac99-ae05-4aac-a237-77025a68eb4f"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[name: \"/device:CPU:0\"\n",
       " device_type: \"CPU\"\n",
       " memory_limit: 268435456\n",
       " locality {\n",
       " }\n",
       " incarnation: 13774914876410535605, name: \"/device:GPU:0\"\n",
       " device_type: \"GPU\"\n",
       " memory_limit: 11281989632\n",
       " locality {\n",
       "   bus_id: 1\n",
       "   links {\n",
       "   }\n",
       " }\n",
       " incarnation: 15095074001980503882\n",
       " physical_device_desc: \"device: 0, name: Tesla K80, pci bus id: 0000:00:04.0, compute capability: 3.7\"]"
      ]
     },
     "execution_count": 2,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from tensorflow.python.client import device_lib\n",
    "device_lib.list_local_devices()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "XOsTN_gjzonZ"
   },
   "source": [
    "# Step 2: Git clone faceswap-GAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 119
    },
    "colab_type": "code",
    "id": "L4VioT5YpJQB",
    "outputId": "1824cbe3-1e3b-40c0-967f-27a30530fcb0"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cloning into 'faceswap-GAN'...\n",
      "remote: Counting objects: 776, done.\u001b[K\n",
      "remote: Compressing objects: 100% (66/66), done.\u001b[K\n",
      "remote: Total 776 (delta 41), reused 32 (delta 13), pack-reused 689\u001b[K\n",
      "Receiving objects: 100% (776/776), 2.15 MiB | 2.03 MiB/s, done.\n",
      "Resolving deltas: 100% (464/464), done.\n"
     ]
    }
   ],
   "source": [
    "!git clone https://github.com/shaoanlu/faceswap-GAN.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "g_4C8o3y0DFD",
    "outputId": "73c9e73b-c5b2-40fe-a96c-1add0e4c70f2"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/content/faceswap-GAN\n"
     ]
    }
   ],
   "source": [
    "%cd \"faceswap-GAN\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "K0wsds0OslRc"
   },
   "source": [
    "# Step 3: Upload training videos\n",
    "\n",
    "The user should upload two videos: **source video** and **target video**. The model will **tranform source face to target face by default.**\n",
    "\n",
    "  - The videos better **contain only one person**.\n",
    "  - There is no limitation on video length but the longer it is, the longer preprocessing time / video conversion time it will take, which may cause excceded run time of 12 hrs. (**Recommended video length: 30 secs ~ 2 mins.**)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "EQ0HH6VUpJHW"
   },
   "outputs": [],
   "source": [
    "from google.colab import files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "thlzp88qpJJk"
   },
   "outputs": [],
   "source": [
    "# Upload source video\n",
    "source_video = files.upload()\n",
    "\n",
    "for fn_source_video, _ in source_video.items():\n",
    "    print(fn_source_video)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "YyKu2hcCpJKu"
   },
   "outputs": [],
   "source": [
    "# Upload target video\n",
    "target_video = files.upload()\n",
    "\n",
    "for fn_target_video, _ in target_video.items():\n",
    "    print(fn_target_video)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Step 4: Set maximum training iterations\n",
    "Default 25000 iters require ~ 10hrs of training.\n",
    "\n",
    "Iterations >= 27k may exceed run time limit; Iterations < 18k may yield poorly-trained model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "global TOTAL_ITERS\n",
    "TOTAL_ITERS = 34000"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "PS25Uu9kxDwo"
   },
   "source": [
    "# Step 5: Everything is ready.\n",
    "\n",
    "**Press Ctrl + F10 (or runtime -> run after)** to start the remaining process and leave this page alone. It will take 10 ~ 12 hours to finish training. The result video can be downloaded by running the last cell: \n",
    "  ```python\n",
    "  files.download(\"OUTPUT_VIDEO.mp4\")\n",
    "  # Some browsers do not support this line (e.g., Opera does not pop up a save dialog). Please use Firefox or Chrome.\n",
    "  ```\n",
    "Notice that **this page should not be closed or refreshed while running**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "BY3qysVq0p2P"
   },
   "outputs": [],
   "source": [
    "%%capture\n",
    "!pip install moviepy\n",
    "!pip install keras_vggface\n",
    "import imageio\n",
    "imageio.plugins.ffmpeg.download()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ym0EsJk9pJRw"
   },
   "outputs": [],
   "source": [
    "import keras.backend as K\n",
    "from detector.face_detector import MTCNNFaceDetector\n",
    "import glob\n",
    "\n",
    "from preprocess import preprocess_video"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "otuzS3Di0gvz"
   },
   "outputs": [],
   "source": [
    "fd = MTCNNFaceDetector(sess=K.get_session(), model_path=\"./mtcnn_weights/\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Kt5FVEt11B2K"
   },
   "outputs": [],
   "source": [
    "!mkdir -p faceA/rgb\n",
    "!mkdir -p faceA/binary_mask\n",
    "!mkdir -p faceB/rgb\n",
    "!mkdir -p faceB/binary_mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "dO11aRsZ0gyK"
   },
   "outputs": [],
   "source": [
    "save_interval = 5 # perform face detection every {save_interval} frames\n",
    "save_path = \"./faceA/\"\n",
    "preprocess_video(fn_source_video, fd, save_interval, save_path)\n",
    "save_path = \"./faceB/\"\n",
    "preprocess_video(fn_target_video, fd, save_interval, save_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "qIb9TSMz0g0g"
   },
   "outputs": [],
   "source": [
    "print(str(len(glob.glob(\"faceA/rgb/*.*\"))) + \" face(s) extracted from source video: \" + fn_source_video + \".\")\n",
    "print(str(len(glob.glob(\"faceB/rgb/*.*\"))) + \" face(s) extracted from target video: \" + fn_target_video + \".\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "O9L5UW8U3AaQ"
   },
   "source": [
    "## The following cells are from [FaceSwap_GAN_v2.2_train_test.ipynb](https://github.com/shaoanlu/faceswap-GAN/blob/master/FaceSwap_GAN_v2.2_train_test.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "WtyeXpEc2lDO"
   },
   "source": [
    "## Import packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "MtENXluJ0g2I",
    "outputId": "13491684-fc62-4021-c089-5921ae21df44"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "from keras.layers import *\n",
    "import keras.backend as K\n",
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "HFAokeU12Vff"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import cv2\n",
    "import glob\n",
    "import time\n",
    "import numpy as np\n",
    "from pathlib import PurePath, Path\n",
    "from IPython.display import clear_output\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "EJ1miHVk2ns7"
   },
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "BjrPsIbZ2ViU"
   },
   "outputs": [],
   "source": [
    "K.set_learning_phase(1)\n",
    "# Number of CPU cores\n",
    "num_cpus = os.cpu_count()\n",
    "\n",
    "# Input/Output resolution\n",
    "RESOLUTION = 64 # 64x64, 128x128, 256x256\n",
    "assert (RESOLUTION % 64) == 0, \"RESOLUTION should be 64, 128, or 256.\"\n",
    "\n",
    "# Batch size\n",
    "batchSize = 4\n",
    "\n",
    "# Use motion blurs (data augmentation)\n",
    "# set True if training data contains images extracted from videos\n",
    "use_da_motion_blur = False \n",
    "\n",
    "# Use eye-aware training\n",
    "# require images generated from prep_binary_masks.ipynb\n",
    "use_bm_eyes = True\n",
    "\n",
    "# Probability of random color matching (data augmentation)\n",
    "prob_random_color_match = 0.5\n",
    "\n",
    "da_config = {\n",
    "    \"prob_random_color_match\": prob_random_color_match,\n",
    "    \"use_da_motion_blur\": use_da_motion_blur,\n",
    "    \"use_bm_eyes\": use_bm_eyes\n",
    "}\n",
    "\n",
    "# Path to training images\n",
    "img_dirA = './faceA/rgb'\n",
    "img_dirB = './faceB/rgb'\n",
    "img_dirA_bm_eyes = \"./faceA/binary_mask\"\n",
    "img_dirB_bm_eyes = \"./faceB/binary_mask\"\n",
    "\n",
    "# Path to saved model weights\n",
    "models_dir = \"./models\"\n",
    "\n",
    "# Architecture configuration\n",
    "arch_config = {}\n",
    "arch_config['IMAGE_SHAPE'] = (RESOLUTION, RESOLUTION, 3)\n",
    "arch_config['use_self_attn'] = True\n",
    "arch_config['norm'] = \"hybrid\" # instancenorm, batchnorm, layernorm, groupnorm, none\n",
    "arch_config['model_capacity'] = \"lite\" # standard, lite\n",
    "\n",
    "# Loss function weights configuration\n",
    "loss_weights = {}\n",
    "loss_weights['w_D'] = 0.1 # Discriminator\n",
    "loss_weights['w_recon'] = 1. # L1 reconstruction loss\n",
    "loss_weights['w_edge'] = 0.1 # edge loss\n",
    "loss_weights['w_eyes'] = 30. # reconstruction and edge loss on eyes area\n",
    "loss_weights['w_pl'] = (0.01, 0.1, 0.3, 0.1) # perceptual loss (0.003, 0.03, 0.3, 0.3)\n",
    "\n",
    "# Init. loss config.\n",
    "loss_config = {}\n",
    "loss_config[\"gan_training\"] = \"mixup_LSGAN\"\n",
    "loss_config['use_PL'] = False\n",
    "loss_config[\"PL_before_activ\"] = True\n",
    "loss_config['use_mask_hinge_loss'] = False\n",
    "loss_config['m_mask'] = 0.\n",
    "loss_config['lr_factor'] = 1.\n",
    "loss_config['use_cyclic_loss'] = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "bSrt2h0K3so3"
   },
   "source": [
    "## Build the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "5o2dEVW92Vms"
   },
   "outputs": [],
   "source": [
    "from networks.faceswap_gan_model import FaceswapGANModel\n",
    "from data_loader.data_loader import DataLoader\n",
    "from utils import showG, showG_mask, showG_eyes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "yiNsc3N_2VhU"
   },
   "outputs": [],
   "source": [
    "model = FaceswapGANModel(**arch_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "!wget https://github.com/rcmalli/keras-vggface/releases/download/v2.0/rcmalli_vggface_tf_notop_resnet50.h5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "id": "qvVr2TqI3jd9",
    "outputId": "0da0a7e2-4f07-4776-9d71-58cd9856dbc3"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading data from https://github.com/rcmalli/keras-vggface/releases/download/v2.0/rcmalli_vggface_tf_notop_resnet50.h5\n",
      "94699520/94694792 [==============================] - 30s 0us/step\n"
     ]
    }
   ],
   "source": [
    "#from keras_vggface.vggface import VGGFace\n",
    "\n",
    "# VGGFace ResNet50\n",
    "#vggface = VGGFace(include_top=False, model='resnet50', input_shape=(224, 224, 3))'\n",
    "\n",
    "from colab_demo.vggface_models import RESNET50\n",
    "vggface = RESNET50(include_top=False, weights=None, input_shape=(224, 224, 3))\n",
    "vggface.load_weights(\"rcmalli_vggface_tf_notop_resnet50.h5\")\n",
    "\n",
    "#from keras.applications.resnet50 import ResNet50\n",
    "#vggface = ResNet50(include_top=False, input_shape=(224, 224, 3))\n",
    "\n",
    "#vggface.summary()\n",
    "\n",
    "model.build_pl_model(vggface_model=vggface, before_activ=loss_config[\"PL_before_activ\"])\n",
    "model.build_train_functions(loss_weights=loss_weights, **loss_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "GexJ-i7c3vz2"
   },
   "source": [
    "## Start training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "yoCnWQiR3jgd"
   },
   "outputs": [],
   "source": [
    "# Create ./models directory\n",
    "Path(f\"models\").mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "bjkSygNT3jjB"
   },
   "outputs": [],
   "source": [
    "# Get filenames\n",
    "train_A = glob.glob(img_dirA+\"/*.*\")\n",
    "train_B = glob.glob(img_dirB+\"/*.*\")\n",
    "\n",
    "train_AnB = train_A + train_B\n",
    "\n",
    "assert len(train_A), \"No image found in \" + str(img_dirA)\n",
    "assert len(train_B), \"No image found in \" + str(img_dirB)\n",
    "print (\"Number of images in folder A: \" + str(len(train_A)))\n",
    "print (\"Number of images in folder B: \" + str(len(train_B)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "-KzKahMU3jlX"
   },
   "outputs": [],
   "source": [
    "def show_loss_config(loss_config):\n",
    "    for config, value in loss_config.items():\n",
    "        print(f\"{config} = {value}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "j5fwrKth3jqO"
   },
   "outputs": [],
   "source": [
    "def reset_session(save_path):\n",
    "    global model, vggface\n",
    "    global train_batchA, train_batchB\n",
    "    model.save_weights(path=save_path)\n",
    "    del model\n",
    "    del vggface\n",
    "    del train_batchA\n",
    "    del train_batchB\n",
    "    K.clear_session()\n",
    "    model = FaceswapGANModel(**arch_config)\n",
    "    model.load_weights(path=save_path)\n",
    "    #vggface = VGGFace(include_top=False, model='resnet50', input_shape=(224, 224, 3))\n",
    "    vggface = RESNET50(include_top=False, weights=None, input_shape=(224, 224, 3))\n",
    "    vggface.load_weights(\"rcmalli_vggface_tf_notop_resnet50.h5\")\n",
    "    model.build_pl_model(vggface_model=vggface, before_activ=loss_config[\"PL_before_activ\"])\n",
    "    train_batchA = DataLoader(train_A, train_AnB, batchSize, img_dirA_bm_eyes,\n",
    "                              RESOLUTION, num_cpus, K.get_session(), **da_config)\n",
    "    train_batchB = DataLoader(train_B, train_AnB, batchSize, img_dirB_bm_eyes, \n",
    "                              RESOLUTION, num_cpus, K.get_session(), **da_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Ll7g7mGU3jss"
   },
   "outputs": [],
   "source": [
    "# Start training\n",
    "t0 = time.time()\n",
    "\n",
    "# This try/except is meant to resume training if we disconnected from Colab\n",
    "try:\n",
    "    gen_iterations\n",
    "    print(f\"Resume training from iter {gen_iterations}.\")\n",
    "except:\n",
    "    gen_iterations = 0\n",
    "\n",
    "errGA_sum = errGB_sum = errDA_sum = errDB_sum = 0\n",
    "errGAs = {}\n",
    "errGBs = {}\n",
    "# Dictionaries are ordered in Python 3.6\n",
    "for k in ['ttl', 'adv', 'recon', 'edge', 'pl']:\n",
    "    errGAs[k] = 0\n",
    "    errGBs[k] = 0\n",
    "\n",
    "display_iters = 300\n",
    "global TOTAL_ITERS\n",
    "\n",
    "global train_batchA, train_batchB\n",
    "train_batchA = DataLoader(train_A, train_AnB, batchSize, img_dirA_bm_eyes, \n",
    "                          RESOLUTION, num_cpus, K.get_session(), **da_config)\n",
    "train_batchB = DataLoader(train_B, train_AnB, batchSize, img_dirB_bm_eyes, \n",
    "                          RESOLUTION, num_cpus, K.get_session(), **da_config)\n",
    "\n",
    "while gen_iterations <= TOTAL_ITERS: \n",
    "    \n",
    "    # Loss function automation\n",
    "    if gen_iterations == (TOTAL_ITERS//5 - display_iters//2):\n",
    "        clear_output()\n",
    "        loss_config['use_PL'] = True\n",
    "        loss_config['use_mask_hinge_loss'] = False\n",
    "        loss_config['m_mask'] = 0.0\n",
    "        reset_session(models_dir)\n",
    "        print(\"Building new loss funcitons...\")\n",
    "        show_loss_config(loss_config)\n",
    "        model.build_train_functions(loss_weights=loss_weights, **loss_config)\n",
    "        print(\"Done.\")\n",
    "    elif gen_iterations == (TOTAL_ITERS//5 + TOTAL_ITERS//10 - display_iters//2):\n",
    "        clear_output()\n",
    "        loss_config['use_PL'] = True\n",
    "        loss_config['use_mask_hinge_loss'] = True\n",
    "        loss_config['m_mask'] = 0.5\n",
    "        reset_session(models_dir)\n",
    "        print(\"Building new loss funcitons...\")\n",
    "        show_loss_config(loss_config)\n",
    "        model.build_train_functions(loss_weights=loss_weights, **loss_config)\n",
    "        print(\"Complete.\")\n",
    "    elif gen_iterations == (2*TOTAL_ITERS//5 - display_iters//2):\n",
    "        clear_output()\n",
    "        loss_config['use_PL'] = True\n",
    "        loss_config['use_mask_hinge_loss'] = True\n",
    "        loss_config['m_mask'] = 0.2\n",
    "        reset_session(models_dir)\n",
    "        print(\"Building new loss funcitons...\")\n",
    "        show_loss_config(loss_config)\n",
    "        model.build_train_functions(loss_weights=loss_weights, **loss_config)\n",
    "        print(\"Done.\")\n",
    "    elif gen_iterations == (TOTAL_ITERS//2 - display_iters//2):\n",
    "        clear_output()\n",
    "        loss_config['use_PL'] = True\n",
    "        loss_config['use_mask_hinge_loss'] = True\n",
    "        loss_config['m_mask'] = 0.4\n",
    "        loss_config['lr_factor'] = 0.3\n",
    "        reset_session(models_dir)\n",
    "        print(\"Building new loss funcitons...\")\n",
    "        show_loss_config(loss_config)\n",
    "        model.build_train_functions(loss_weights=loss_weights, **loss_config)\n",
    "        print(\"Done.\")\n",
    "    elif gen_iterations == (2*TOTAL_ITERS//3 - display_iters//2):\n",
    "        clear_output()\n",
    "        model.decoder_A.load_weights(\"models/decoder_B.h5\") # swap decoders\n",
    "        model.decoder_B.load_weights(\"models/decoder_A.h5\") # swap decoders\n",
    "        loss_config['use_PL'] = True\n",
    "        loss_config['use_mask_hinge_loss'] = True\n",
    "        loss_config['m_mask'] = 0.5\n",
    "        loss_config['lr_factor'] = 1\n",
    "        reset_session(models_dir)\n",
    "        print(\"Building new loss funcitons...\")\n",
    "        show_loss_config(loss_config)\n",
    "        model.build_train_functions(loss_weights=loss_weights, **loss_config)\n",
    "        print(\"Done.\")\n",
    "    elif gen_iterations == (8*TOTAL_ITERS//10 - display_iters//2):\n",
    "        clear_output()\n",
    "        loss_config['use_PL'] = True\n",
    "        loss_config['use_mask_hinge_loss'] = True\n",
    "        loss_config['m_mask'] = 0.1\n",
    "        loss_config['lr_factor'] = 0.3\n",
    "        reset_session(models_dir)\n",
    "        print(\"Building new loss funcitons...\")\n",
    "        show_loss_config(loss_config)\n",
    "        model.build_train_functions(loss_weights=loss_weights, **loss_config)\n",
    "        print(\"Done.\")\n",
    "    elif gen_iterations == (9*TOTAL_ITERS//10 - display_iters//2):\n",
    "        clear_output()\n",
    "        loss_config['use_PL'] = True\n",
    "        loss_config['use_mask_hinge_loss'] = False\n",
    "        loss_config['m_mask'] = 0.0\n",
    "        loss_config['lr_factor'] = 0.1\n",
    "        reset_session(models_dir)\n",
    "        print(\"Building new loss funcitons...\")\n",
    "        show_loss_config(loss_config)\n",
    "        model.build_train_functions(loss_weights=loss_weights, **loss_config)\n",
    "        print(\"Done.\")\n",
    "    \n",
    "    if gen_iterations == 5:\n",
    "        print (\"working.\")\n",
    "    \n",
    "    # Train dicriminators for one batch\n",
    "    data_A = train_batchA.get_next_batch()\n",
    "    data_B = train_batchB.get_next_batch()\n",
    "    errDA, errDB = model.train_one_batch_D(data_A=data_A, data_B=data_B)\n",
    "    errDA_sum +=errDA[0]\n",
    "    errDB_sum +=errDB[0]\n",
    "\n",
    "    # Train generators for one batch\n",
    "    data_A = train_batchA.get_next_batch()\n",
    "    data_B = train_batchB.get_next_batch()\n",
    "    errGA, errGB = model.train_one_batch_G(data_A=data_A, data_B=data_B)\n",
    "    errGA_sum += errGA[0]\n",
    "    errGB_sum += errGB[0]\n",
    "    for i, k in enumerate(['ttl', 'adv', 'recon', 'edge', 'pl']):\n",
    "        errGAs[k] += errGA[i]\n",
    "        errGBs[k] += errGB[i]\n",
    "    gen_iterations+=1\n",
    "    \n",
    "    # Visualization\n",
    "    if gen_iterations % display_iters == 0:\n",
    "        clear_output()\n",
    "            \n",
    "        # Display loss information\n",
    "        show_loss_config(loss_config)\n",
    "        print(\"----------\") \n",
    "        print('[iter %d] Loss_DA: %f Loss_DB: %f Loss_GA: %f Loss_GB: %f time: %f'\n",
    "        % (gen_iterations, errDA_sum/display_iters, errDB_sum/display_iters,\n",
    "           errGA_sum/display_iters, errGB_sum/display_iters, time.time()-t0))  \n",
    "        print(\"----------\") \n",
    "        print(\"Generator loss details:\")\n",
    "        print(f'[Adversarial loss]')  \n",
    "        print(f'GA: {errGAs[\"adv\"]/display_iters:.4f} GB: {errGBs[\"adv\"]/display_iters:.4f}')\n",
    "        print(f'[Reconstruction loss]')\n",
    "        print(f'GA: {errGAs[\"recon\"]/display_iters:.4f} GB: {errGBs[\"recon\"]/display_iters:.4f}')\n",
    "        print(f'[Edge loss]')\n",
    "        print(f'GA: {errGAs[\"edge\"]/display_iters:.4f} GB: {errGBs[\"edge\"]/display_iters:.4f}')\n",
    "        if loss_config['use_PL'] == True:\n",
    "            print(f'[Perceptual loss]')\n",
    "            try:\n",
    "                print(f'GA: {errGAs[\"pl\"][0]/display_iters:.4f} GB: {errGBs[\"pl\"][0]/display_iters:.4f}')\n",
    "            except:\n",
    "                print(f'GA: {errGAs[\"pl\"]/display_iters:.4f} GB: {errGBs[\"pl\"]/display_iters:.4f}')\n",
    "        \n",
    "        # Display images\n",
    "        print(\"----------\") \n",
    "        wA, tA, _ = train_batchA.get_next_batch()\n",
    "        wB, tB, _ = train_batchB.get_next_batch()\n",
    "        print(\"Transformed (masked) results:\")\n",
    "        showG(tA, tB, model.path_A, model.path_B, batchSize)   \n",
    "        print(\"Masks:\")\n",
    "        showG_mask(tA, tB, model.path_mask_A, model.path_mask_B, batchSize)  \n",
    "        print(\"Reconstruction results:\")\n",
    "        showG(wA, wB, model.path_bgr_A, model.path_bgr_B, batchSize)           \n",
    "        errGA_sum = errGB_sum = errDA_sum = errDB_sum = 0\n",
    "        for k in ['ttl', 'adv', 'recon', 'edge', 'pl']:\n",
    "            errGAs[k] = 0\n",
    "            errGBs[k] = 0\n",
    "        \n",
    "        # Save models\n",
    "        model.save_weights(path=models_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "kRD_02DL3ju9"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "lSqr_UTX4wRE"
   },
   "source": [
    "## The following cells are from [FaceSwap_GAN_v2.2_video_conversion.ipynb](https://github.com/shaoanlu/faceswap-GAN/blob/master/FaceSwap_GAN_v2.2_video_conversion.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "x58lrjkw6iQM"
   },
   "source": [
    "## Video conversion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "EOi2ZsSa4vf4"
   },
   "outputs": [],
   "source": [
    "from converter.video_converter import VideoConverter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "global model, vggface\n",
    "global train_batchA, train_batchB\n",
    "del model\n",
    "del vggface\n",
    "del train_batchA\n",
    "del train_batchB\n",
    "tf.reset_default_graph()\n",
    "K.clear_session()\n",
    "model = FaceswapGANModel(**arch_config)\n",
    "model.load_weights(path=models_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "R4uUub3A4vi7"
   },
   "outputs": [],
   "source": [
    "fd = MTCNNFaceDetector(sess=K.get_session(), model_path=\"./mtcnn_weights/\")\n",
    "vc = VideoConverter()\n",
    "vc.set_face_detector(fd)\n",
    "vc.set_gan_model(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "MxwjrVoc4vmQ"
   },
   "outputs": [],
   "source": [
    "options = {\n",
    "    # ===== Fixed =====\n",
    "    \"use_smoothed_bbox\": True,\n",
    "    \"use_kalman_filter\": True,\n",
    "    \"use_auto_downscaling\": False,\n",
    "    \"bbox_moving_avg_coef\": 0.65,\n",
    "    \"min_face_area\": 35 * 35,\n",
    "    \"IMAGE_SHAPE\": model.IMAGE_SHAPE,\n",
    "    # ===== Tunable =====\n",
    "    \"kf_noise_coef\": 1e-3,\n",
    "    \"use_color_correction\": \"hist_match\",\n",
    "    \"detec_threshold\": 0.8,\n",
    "    \"roi_coverage\": 0.9,\n",
    "    \"enhance\": 0.,\n",
    "    \"output_type\": 3,\n",
    "    \"direction\": \"AtoB\", # ==================== This line determines the transform direction ====================\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "qyuELtb94vpr"
   },
   "outputs": [],
   "source": [
    "if options[\"direction\"] == \"AtoB\":\n",
    "    input_fn = fn_source_video\n",
    "    output_fn = \"OUTPUT_VIDEO_AtoB.mp4\"\n",
    "elif options[\"direction\"] == \"BtoA\":\n",
    "    input_fn = fn_target_video\n",
    "    output_fn = \"OUTPUT_VIDEO_BtoA.mp4\"\n",
    "\n",
    "duration = None # None or a non-negative float tuple: (start_sec, end_sec). Duration of input video to be converted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Sem7RMzm4vr6"
   },
   "outputs": [],
   "source": [
    "vc.convert(input_fn=input_fn, output_fn=output_fn, options=options, duration=duration)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "QwKOc5Jt5oFO"
   },
   "source": [
    "# Download result video"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "M0MkktlHC5sh"
   },
   "outputs": [],
   "source": [
    "from google.colab import files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "hwDUxwlL4voW"
   },
   "outputs": [],
   "source": [
    "if options[\"direction\"] == \"AtoB\":\n",
    "    files.download(\"OUTPUT_VIDEO_AtoB.mp4\")\n",
    "elif options[\"direction\"] == \"BtoA\":\n",
    "    files.download(\"OUTPUT_VIDEO_BtoA.mp4\")"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "faceswap-GAN_lite_demo.ipynb",
   "provenance": [],
   "version": "0.3.2"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
